import euclideanDistance from "./euclidean_distance.js";
import makeMatrix from "./make_matrix.js";
import max from "./max.js";

/**
 * 计算聚类数据的[轮廓系数](https://en.wikipedia.org/wiki/Silhouette_(clustering))
 *
 * @param {Array<Array<number>>} points N维数据点的坐标数组
 * @param {Array<number>} labels 数据点标签数组，长度必须与points一致，
 * 且取值范围为[0..G-1]，其中G为分组总数
 * @return {Array<number>} 各数据点的轮廓系数值
 *
 * @example
 * silhouette([[0.25], [0.75]], [0, 0]); // => [1.0, 1.0]
 */
function silhouette(points, labels) {
    if (points.length !== labels.length) {
        throw new Error("标签数量必须与数据点数量严格一致");
    }
    const groupings = createGroups(labels);
    const distances = calculateAllDistances(points);
    const result = [];
    for (let i = 0; i < points.length; i++) {
        let s = 0;
        if (groupings[labels[i]].length > 1) {
            const a = meanDistanceFromPointToGroup(
                i,
                groupings[labels[i]],
                distances
            );
            const b = meanDistanceToNearestGroup(
                i,
                labels,
                groupings,
                distances
            );
            s = (b - a) / Math.max(a, b);
        }
        result.push(s);
    }
    return result;
}

/**
 * 创建组ID到点ID的查找表
 *
 * @private
 * @param {Array<number>} labels 数据点标签数组，长度必须与points一致，
 * 且取值范围为[0..G-1]，其中G为分组总数
 * @return {Array<Array<number>>} 长度G的数组，每个元素为对应组内数据点索引的数组
 */
function createGroups(labels) {
    const numGroups = 1 + max(labels);
    const result = Array(numGroups);
    for (let i = 0; i < labels.length; i++) {
        const label = labels[i];
        if (result[label] === undefined) {
            result[label] = [];
        }
        result[label].push(i);
    }
    return result;
}

/**
 * 创建全量点间距离查找表
 *
 * @private
 * @param {Array<Array<number>>} points N维数据点的坐标数组
 * @return {Array<Array<number>>} 对称方阵形式的点间距离矩阵（主对角线为零）
 */
function calculateAllDistances(points) {
    const numPoints = points.length;
    const result = makeMatrix(numPoints, numPoints);
    for (let i = 0; i < numPoints; i++) {
        for (let j = 0; j < i; j++) {
            result[i][j] = euclideanDistance(points[i], points[j]);
            result[j][i] = result[i][j];
        }
    }
    return result;
}

/**
 * 计算当前点到最近组（由最近邻点确定）的平均距离
 *
 * @private
 * @param {number} which 当前点索引
 * @param {Array<number>} labels 数据点标签数组
 * @param {Array<Array<number>>} groupings 组结构数组，每个元素为对应组内数据点索引的数组
 * @param {Array<Array<number>>} distances 对称方阵形式的点间距离矩阵
 * @return {number} 当前点到最近组的平均距离
 */
function meanDistanceToNearestGroup(which, labels, groupings, distances) {
    const label = labels[which];
    let result = Number.MAX_VALUE;
    for (let i = 0; i < groupings.length; i++) {
        if (i !== label) {
            const d = meanDistanceFromPointToGroup(
                which,
                groupings[i],
                distances
            );
            if (d < result) {
                result = d;
            }
        }
    }
    return result;
}

/**
 * 计算点到指定组的平均距离（可包含自身所在组）
 *
 * @private
 * @param {number} which 当前点索引
 * @param {Array<number>} group 目标组的数据点索引数组
 * @param {Array<Array<number>>} distances 对称方阵形式的点间距离矩阵
 * @return {number} 当前点到目标组的平均距离
 */
function meanDistanceFromPointToGroup(which, group, distances) {
    let total = 0;
    for (let i = 0; i < group.length; i++) {
        total += distances[which][group[i]];
    }
    return total / group.length;
}

export default silhouette;
